{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Google Colab](https://badgen.net/badge/Lancer/run%20Google%20Colab/orange?icon=terminal)](https://colab.research.google.com/github/BrikerMan/Kashgari/blob/kashgari2/examples/train_with_generator.ipynb)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  pass\n",
    "\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/BrikerMan/Kashgari.git@kashgari2\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.utils import get_file\n",
    "\n",
    "def download_data(duplicate=1000):\n",
    "    url_list = [\n",
    "        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.train.w-intent.iob',\n",
    "        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.dev.w-intent.iob',\n",
    "        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.test.w-intent.iob',\n",
    "        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.train.w-intent.iob'\n",
    "    ]\n",
    "    files = []\n",
    "    for url in url_list:\n",
    "        files.append(get_file(url.split('/')[-1], url))\n",
    "\n",
    "    return files * duplicate\n",
    "\n",
    "corpus_files = download_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kashgari.generators import ABCGenerator\n",
    "\n",
    "# Define you classification data generator\n",
    "class ClassificationGenerator:\n",
    "    def __init__(self, files):\n",
    "        self.files = files\n",
    "        self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files)\n",
    "\n",
    "    @property\n",
    "    def steps(self) -> int:\n",
    "        return self._line_count\n",
    "\n",
    "    def __iter__(self):\n",
    "        for file in self.files:\n",
    "            with open(file, 'r') as f:\n",
    "                for line in f:\n",
    "                    rows = line.split('\\t')\n",
    "                    x = rows[0].strip().split(' ')[1:-1]\n",
    "                    y = rows[1].strip().split(' ')[-1]\n",
    "                    yield x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kashgari.tasks.classification import BiGRU_Model\n",
    "files = download_data()\n",
    "gen = ClassificationGenerator(files)\n",
    "\n",
    "model = BiGRU_Model()\n",
    "model.fit_generator(gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Ner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelingGenerator(ABCGenerator):\n",
    "    def __init__(self, files):\n",
    "        self.files = files\n",
    "        self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files)\n",
    "\n",
    "    @property\n",
    "    def steps(self) -> int:\n",
    "        return self._line_count\n",
    "\n",
    "    def __iter__(self):\n",
    "        for file in self.files:\n",
    "            with open(file, 'r') as f:\n",
    "                for line in f:\n",
    "                    rows = line.split('\\t')\n",
    "                    x = rows[0].strip().split(' ')[1:-1]\n",
    "                    y = rows[1].strip().split(' ')[1:-1]\n",
    "                    yield x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kashgari.tasks.labeling import BiGRU_Model\n",
    "files = download_data()\n",
    "gen = LabelingGenerator(files)\n",
    "\n",
    "model = BiGRU_Model()\n",
    "model.fit_generator(gen)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
